Skip to main content

Tree

  • Tree problems usually revolve around DFS or BFS combined with some specific tree property, so it's important to be able to implement these 2 algorithms blindfolded

Overview

Traversal

DFS

  • The idea is usually to use recursion to traverse the left and right subtrees while performing some operation

  • Sometimes it's helpful to pass extra parameters to the recursive function to keep track of some data (e.g. path sum, current node path, etc.)

  • Preorder Traversal

def preorderTraversal(self, root: Optional[TreeNode]) -> List[int]:
    """Return the node values in preorder: node, then left subtree, then right subtree."""
    values = []

    def visit(node):
        if not node:
            return
        values.append(node.val)  # the node itself comes first
        visit(node.left)
        visit(node.right)

    visit(root)
    return values
  • Inorder Traversal
def inorderTraversal(self, root: Optional[TreeNode]) -> List[int]:
    """Return the node values in inorder: left subtree, then node, then right subtree."""
    values = []

    def visit(node):
        if not node:
            return
        visit(node.left)
        values.append(node.val)  # the node sits between its two subtrees
        visit(node.right)

    visit(root)
    return values
  • Postorder Traversal
def postorderTraversal(self, root: Optional[TreeNode]) -> List[int]:
    """Return the node values in postorder: left subtree, then right subtree, then node."""
    values = []

    def visit(node):
        if not node:
            return
        visit(node.left)
        visit(node.right)
        values.append(node.val)  # the node itself comes last

    visit(root)
    return values
def buildTree(self, preorder: List[int], inorder: List[int]) -> Optional[TreeNode]:
    """Rebuild a binary tree from its preorder and inorder traversals.

    Locating the root with inorder.index() presumes node values are unique.
    """
    if not preorder or not inorder:
        return None

    root = TreeNode(preorder[0])
    # preorder: root + left + right
    # inorder:  left + root + right
    # rootIndex is both the root's position in inorder and the size of the
    # left subtree, so it splits preorder and inorder into matching left/right
    # parts. The root itself belongs to neither part.
    rootIndex = inorder.index(root.val)
    root.left = self.buildTree(preorder[1:rootIndex + 1], inorder[:rootIndex])
    root.right = self.buildTree(preorder[rootIndex + 1:], inorder[rootIndex + 1:])

    return root
# https://leetcode.com/problems/path-sum/
def hasPathSum(self, root: Optional[TreeNode], targetSum: int) -> bool:
    """Return True if some root-to-leaf path has node values summing to targetSum."""
    def reaches_target(node, running):
        # running = sum of the values on the path strictly above `node`
        if not node:
            return False
        running = running + node.val
        if not node.left and not node.right:
            return running == targetSum  # only a leaf may end a path
        return reaches_target(node.left, running) or reaches_target(node.right, running)

    return reaches_target(root, 0)

https://leetcode.com/problems/boundary-of-binary-tree

  • Actual implementation is not bad, good practice for DFS thinking
  • The hard part is understanding what the description is actually asking for
 def boundaryOfBinaryTree(self, root: Optional[TreeNode]) -> List[int]:
     """Return the tree boundary counter-clockwise, starting at the root.

     Order: root, left boundary (top-down, leaves excluded), every leaf
     (left to right), right boundary (bottom-up, leaves excluded).
     """
     if not root:
         return []

     def is_leaf(node):
         return not node.left and not node.right

     # Left boundary: from root.left, prefer the left child, stop at a leaf.
     left_side = []
     node = root.left
     while node and not is_leaf(node):
         left_side.append(node.val)
         node = node.left if node.left else node.right

     # Right boundary: the mirrored walk from root.right, then flipped to bottom-up.
     right_side = []
     node = root.right
     while node and not is_leaf(node):
         right_side.append(node.val)
         node = node.right if node.right else node.left
     right_side.reverse()

     # Leaves left to right; a lone root is already emitted and not counted twice.
     leaves = []
     if not is_leaf(root):
         stack = [root]
         while stack:
             node = stack.pop()
             if is_leaf(node):
                 leaves.append(node.val)
                 continue
             # Push right first so the left subtree is explored first.
             if node.right:
                 stack.append(node.right)
             if node.left:
                 stack.append(node.left)

     return [root.val] + left_side + leaves + right_side

BFS

  • The idea is to use queue to keep track of what to traverse next
  • Use when
    • Level order traversal
      • When the problem asks about relationships between nodes at the same height/depth
      • When you need to find the minimum depth or shortest path
    • Parent relationship
      • When you need to keep track of the parent of the current node (and maybe compare it to other nodes on the same level)
# https://leetcode.com/problems/binary-tree-level-order-traversal/
def levelOrder(self, root: Optional[TreeNode]) -> List[List[int]]:
    """Return node values grouped by depth, each level ordered left to right.

    Uses collections.deque. list.pop(0) shifts every remaining element, which
    makes the BFS O(n^2) on wide trees; deque.popleft() is O(1).
    """
    from collections import deque

    res = []
    if not root:
        return res
    queue = deque([root])

    while queue:
        cur_level = []
        # Snapshot the size so that only nodes of the current depth are consumed.
        for _ in range(len(queue)):
            node = queue.popleft()
            cur_level.append(node.val)
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
        res.append(cur_level)

    return res

https://leetcode.com/problems/maximum-level-sum-of-a-binary-tree

 def maxLevelSum(self, root: Optional[TreeNode]) -> int:
     """Return the 1-based level whose node values have the largest sum.

     A strict > comparison keeps the smallest level on ties.
     """
     best_sum = float('-inf')
     best_level = -1
     frontier = [root]  # all nodes on the current level, left to right
     depth = 1

     while frontier:
         level_sum = sum(node.val for node in frontier)
         if level_sum > best_sum:
             best_sum, best_level = level_sum, depth
         frontier = [child for node in frontier for child in (node.left, node.right) if child]
         depth += 1

     return best_level

https://leetcode.com/problems/cousins-in-binary-tree

  def isCousins(self, root: Optional[TreeNode], x: int, y: int) -> bool:
      """Return True if the nodes valued x and y are cousins.

      Cousins sit at the same depth but have different parents. The BFS stops
      at the first level that contains x or y: if the other value is not on
      that same level, the two nodes cannot be cousins.
      """
      from collections import deque

      if not root:
          return False

      queue = deque([(root, None)])  # (node, parent); the root has no parent
      while queue:
          xFound = yFound = False
          xParent = yParent = None
          for _ in range(len(queue)):  # consume exactly one level
              node, parent = queue.popleft()
              if node.val == x:
                  xFound, xParent = True, parent
              elif node.val == y:
                  yFound, yParent = True, parent
              if node.left:
                  queue.append((node.left, node))
              if node.right:
                  queue.append((node.right, node))
          if xFound or yFound:
              # Both must be on this level and hang off different parents.
              return xFound and yFound and xParent is not yParent

      return False

https://leetcode.com/problems/cousins-in-binary-tree-ii

  • More advanced version here
  • The trick here is to efficiently compute, for every node, the sum of all other nodes on its level that do not share its parent
    • Pre-computation: new node.val = levelTotal - node.val - sibling.val => levelTotal can be pre-computed once per level
    • HashTable
  def replaceValueInTree(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
      """Replace every node's value with the sum of its cousins' values (in place).

      Cousins are same-depth nodes with a different parent. A node's new value
      is therefore (level total) - (its own value + its sibling's value).
      """
      if not root:
          return
      frontier = [(root, None)]  # (node, parent) pairs on the current level

      while frontier:
          level_total = sum(node.val for node, _ in frontier)
          # Sum of each family (children of one parent), keyed by parent identity.
          family_sum = {}
          for node, parent in frontier:
              family_sum[id(parent)] = family_sum.get(id(parent), 0) + node.val
          # Collect the next level before values on this one are overwritten.
          next_frontier = [
              (child, node)
              for node, _ in frontier
              for child in (node.left, node.right)
              if child
          ]
          for node, parent in frontier:
              node.val = level_total - family_sum[id(parent)]
          frontier = next_frontier

      return root

Problems

Variants

Binary Search Tree

  • Definition: A tree where each node has at most two children, and for any node:

    • All values in the left subtree are less than the node's value
    • All values in the right subtree are greater than the node's value
    • In-order traversal yields sorted order
  • Efficient for search, insert, and delete operations when balanced: O(log n)

  • These are some of the fundamental problems that demonstrate BST properties

def sortedArrayToBST(self, nums: List[int]) -> Optional[TreeNode]:
    """Build a height-balanced BST from an ascending list.

    The middle element (the upper middle for even lengths) becomes the root,
    and each half is built the same way.
    """
    if not nums:
        return None

    def build(lo, hi):
        # Builds the subtree for the half-open range nums[lo:hi].
        if lo >= hi:
            return None
        mid = lo + (hi - lo) // 2
        node = TreeNode(nums[mid])
        node.left = build(lo, mid)
        node.right = build(mid + 1, hi)
        return node

    return build(0, len(nums))

https://leetcode.com/problems/construct-binary-search-tree-from-preorder-traversal

  • Good question to notice BST properties
def bstFromPreorder(self, preorder: List[int]) -> Optional[TreeNode]:
    """Rebuild a BST from its preorder traversal.

    preorder[0] is the root. The remaining values are the ones <= root
    (left subtree) followed by the ones > root (right subtree), so a binary
    search finds where the right subtree starts.
    """
    if not preorder:
        return
    root = TreeNode(preorder[0])

    # Find the first index after the root whose value exceeds the root's.
    lo, hi = 1, len(preorder) - 1
    while lo <= hi:
        mid = lo + (hi - lo) // 2
        if preorder[mid] <= root.val:
            lo = mid + 1
        else:
            hi = mid - 1
    split = lo

    root.left = self.bstFromPreorder(preorder[1:split])
    root.right = self.bstFromPreorder(preorder[split:])
    return root

https://leetcode.com/problems/validate-binary-search-tree

     def isValidBST(self, root: Optional[TreeNode]) -> bool:
         """Return True if the tree satisfies the strict BST ordering.

         Every node must lie strictly inside the (low, high) window inherited
         from its ancestors. Checked with an explicit stack instead of recursion.
         """
         pending = [(root, float('-inf'), float('inf'))]
         while pending:
             node, low, high = pending.pop()
             if not node:
                 continue
             if node.val <= low or node.val >= high:
                 return False
             # Left descendants must stay below node.val, right ones above it.
             pending.append((node.left, low, node.val))
             pending.append((node.right, node.val, high))
         return True

https://leetcode.com/problems/insert-into-a-binary-search-tree

 def insertIntoBST(self, root: Optional[TreeNode], val: int) -> Optional[TreeNode]:
     """Insert val into the BST and return its root; a value already present is left alone."""
     if not root:
         return TreeNode(val)

     node = root
     while True:
         if val > node.val:
             if not node.right:
                 node.right = TreeNode(val)
                 break
             node = node.right
         elif val < node.val:
             if not node.left:
                 node.left = TreeNode(val)
                 break
             node = node.left
         else:
             break  # duplicate value: nothing to insert
     return root

https://leetcode.com/problems/delete-node-in-a-bst

  • Important algorithm to memorize
  • 3 steps:
    1. Find the node to remove
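    2. With 0 or 1 child, just return the other child; with two children, replace the node's value with its in-order successor (smallest value in the right subtree) or predecessor (largest value in the left subtree)
    3. Remove that successor/predecessor node from the subtree it came from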
def deleteNode(self, root: Optional[TreeNode], key: int) -> Optional[TreeNode]:
    """Delete the node holding key from the BST and return the new subtree root."""
    if not root:
        return
    if key < root.val:
        root.left = self.deleteNode(root.left, key)
        return root
    if key > root.val:
        root.right = self.deleteNode(root.right, key)
        return root

    # root holds key. With at most one child, splice that child in.
    if not root.left:
        return root.right
    if not root.right:
        return root.left

    # Two children: copy in the in-order successor (leftmost node of the right
    # subtree), then delete that successor from the right subtree.
    successor = root.right
    while successor.left:
        successor = successor.left
    root.val = successor.val
    root.right = self.deleteNode(root.right, root.val)
    return root

Lowest Common Ancestor

https://leetcode.com/discuss/interview-question/6024811

https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-tree

  def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
      """Return the lowest common ancestor of p and q, matching nodes by value.

      The search stops at the first node matching either value on each path.
      If only one value is present, that node is returned.
      """
      def search(node):
          if not node or node.val == p.val or node.val == q.val:
              return node
          from_left = search(node.left)
          from_right = search(node.right)
          if from_left and from_right:
              return node  # p and q were found in different subtrees
          return from_left if from_left else from_right

      return search(root)

https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-tree-ii

  • Use the helper/DFS function both to find the common ancestor and to record whether each node exists in the tree
  • Maintain 2 variables to keep track of the existence of each node
def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
    """Return the LCA of p and q (by value), or None unless both values are in the tree.

    Both subtrees are searched before the current node is inspected, so a
    target nested beneath the other target is still recorded as found.
    """
    found = {"p": False, "q": False}

    def search(node):
        if not node:
            return None
        from_left = search(node.left)
        from_right = search(node.right)
        if node.val == p.val:
            found["p"] = True
            return node
        if node.val == q.val:
            found["q"] = True
            return node
        if from_left and from_right:
            return node
        return from_left or from_right

    candidate = search(root)
    return candidate if found["p"] and found["q"] else None

https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-search-tree

# Recursive
def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
    """Return the LCA of p and q in a BST, recursing toward the side holding both values."""
    low, high = sorted((p.val, q.val))
    if high < root.val:
        return self.lowestCommonAncestor(root.left, p, q)   # both strictly smaller
    if low > root.val:
        return self.lowestCommonAncestor(root.right, p, q)  # both strictly larger
    # low <= root.val <= high: root separates p and q (or is one of them).
    return root

# Iterative
def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
    """Iterative BST LCA: walk down until p and q no longer lie on the same side."""
    cur = root
    while cur:
        if p.val < cur.val and q.val < cur.val:
            cur = cur.left
        elif p.val > cur.val and q.val > cur.val:
            cur = cur.right
        else:
            # cur lies between p and q (inclusive), so it is the split point.
            return cur
    return None

https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-tree-iii

  • Walk from one node all the way up to the root, recording every ancestor
  • Walk up from the other node; the first node already recorded by the previous walk is the LCA
  def lowestCommonAncestor(self, p: 'Node', q: 'Node') -> 'Node':
      """Return the LCA of p and q using parent pointers.

      Record every node on p's path to the root, then climb from q until
      reaching a recorded node.
      """
      ancestors = set()
      node = p
      while node:
          ancestors.add(node)
          node = node.parent

      node = q
      while node:
          if node in ancestors:
              return node
          node = node.parent

      return None

https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-tree-iv

def lowestCommonAncestor(self, root: 'TreeNode', nodes: 'List[TreeNode]') -> 'TreeNode':
    """Return the LCA of every node in `nodes` (nodes are matched by hash/equality, not value)."""
    targets = set(nodes)

    def search(node):
        # A target ends the descent: anything below it is covered by it.
        if not node or node in targets:
            return node
        from_left = search(node.left)
        from_right = search(node.right)
        if from_left and from_right:
            return node
        return from_left if from_left else from_right

    return search(root)

https://leetcode.com/problems/lowest-common-ancestor-of-deepest-leaves

   def lcaDeepestLeaves(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
       """Return the lowest common ancestor of all the deepest leaves."""
       def height_and_lca(node):
           # Returns (height of this subtree, LCA of its deepest leaves).
           if not node:
               return 0, None
           left_h, left_lca = height_and_lca(node.left)
           right_h, right_lca = height_and_lca(node.right)
           if left_h > right_h:
               return left_h + 1, left_lca
           if right_h > left_h:
               return right_h + 1, right_lca
           # Equally deep on both sides: this node joins the deepest leaves.
           return left_h + 1, node

       return height_and_lca(root)[1]

Tree Diameter

https://leetcode.com/problems/diameter-of-binary-tree

def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
    """Return the number of edges on the longest path between any two nodes."""
    def height_and_best(node):
        # Returns (height in nodes, longest path in edges inside this subtree).
        if not node:
            return 0, 0
        left_h, left_best = height_and_best(node.left)
        right_h, right_best = height_and_best(node.right)
        through_here = left_h + right_h  # path bending at this node
        return 1 + max(left_h, right_h), max(left_best, right_best, through_here)

    return height_and_best(root)[1]

https://leetcode.com/problems/diameter-of-n-ary-tree

  • Same idea as before, except we take the sum of the 2 longest child paths
 def diameter(self, root: 'Node') -> int:
     """Return the longest path (in edges) between any two nodes of an N-ary tree."""
     longest = 0

     def height(node):
         # Nodes on the longest downward path from node; also updates `longest`.
         nonlocal longest
         if not node:
             return
         heights = sorted((height(child) for child in node.children), reverse=True)
         # The longest path through node joins its two tallest child branches.
         longest = max(longest, sum(heights[:2]))
         return 1 + (heights[0] if heights else 0)

     height(root)
     return longest

Patterns

Tree as Graph

  • Sometimes a problem is given as a tree, but we actually want to treat (maybe convert) the tree as a [[9. Graph| graph]].
  • In a typical tree problem, we only need to
    • Move from parent to children (downward)
    • Process nodes in a specific order
    • Track information along a single path
  • These signals/characteristics in a problem should make you consider transforming the tree into a graph:

Non standard movement requirements

https://leetcode.com/problems/all-nodes-distance-k-in-binary-tree

def distanceK(self, root: TreeNode, target: TreeNode, k: int) -> List[int]:
    """Return the values of all nodes exactly k edges away from target.

    The tree is turned into an undirected adjacency list keyed by node value
    (which presumes unique values). A level-by-level BFS from target then
    runs for k rounds.
    """
    graph = collections.defaultdict(list)
    pending = [(root, None)]  # (node, parent) pairs still to be linked
    while pending:
        node, parent = pending.pop()
        if not node:
            continue
        if parent:
            graph[node.val].append(parent.val)
        for child in (node.left, node.right):
            if child:
                graph[node.val].append(child.val)
                pending.append((child, node))

    if k < 0:
        return []

    seen = {target.val}
    ring = [target.val]  # values at the current distance from target
    for _ in range(k):
        next_ring = []
        for value in ring:
            for neighbor in graph[value]:
                if neighbor not in seen:
                    seen.add(neighbor)
                    next_ring.append(neighbor)
        ring = next_ring
    return ring

https://leetcode.com/problems/amount-of-time-for-binary-tree-to-be-infected/

def amountOfTime(self, root: Optional[TreeNode], start: int) -> int:
    """Return the minutes needed for an infection starting at value `start` to reach every node.

    Each minute the infection spreads to all neighbours (children and parent).
    The answer is therefore the largest BFS distance from `start` in the tree
    viewed as an undirected graph.
    """
    graph = collections.defaultdict(list)
    pending = [(root, None)]  # (node, parent) pairs still to be linked
    while pending:
        node, parent = pending.pop()
        if not node:
            continue
        if parent:
            graph[node.val].append(parent.val)
        for child in (node.left, node.right):
            if child:
                graph[node.val].append(child.val)
                pending.append((child, node))

    infected = {start}
    frontier = [start]
    minutes = -1  # becomes 0 for the level that holds `start` itself
    while frontier:
        minutes += 1
        newly_infected = []
        for value in frontier:
            for neighbor in graph[value]:
                if neighbor not in infected:
                    infected.add(neighbor)
                    newly_infected.append(neighbor)
        frontier = newly_infected
    return minutes

https://leetcode.com/problems/step-by-step-directions-from-a-binary-tree-node-to-another


https://leetcode.com/problems/closest-leaf-in-a-binary-tree

  def findClosestLeaf(self, root: Optional[TreeNode], k: int) -> int:
      """Return the value of the leaf nearest (in edges) to the node valued k, or -1.

      Ties go to whichever leaf the BFS reaches first.
      """
      graph = collections.defaultdict(list)
      leaves = set()
      pending = [(root, None)]  # (node, parent) pairs still to be linked
      while pending:
          node, parent = pending.pop()
          if not node:
              continue
          if parent:
              graph[node.val].append(parent.val)
          if not node.left and not node.right:
              leaves.add(node.val)
          for child in (node.left, node.right):
              if child:
                  graph[node.val].append(child.val)
                  pending.append((child, node))

      seen = {k}
      frontier = [k]
      while frontier:
          next_frontier = []
          for value in frontier:
              if value in leaves:
                  return value
              for neighbor in graph[value]:
                  if neighbor not in seen:
                      seen.add(neighbor)
                      next_frontier.append(neighbor)
          frontier = next_frontier

      return -1

Relationship-based queries

  • When a problem asks about relationships that aren't purely hierarchical, consider a graph
    • Finding nodes at a specific distance
    • Finding the distance between any two nodes
    • Finding all nodes that can reach a target node
    • Finding the shortest path between nodes

https://leetcode.com/problems/find-distance-in-a-binary-tree

def findDistance(self, root: Optional[TreeNode], p: int, q: int) -> int:
    """Return the number of edges between the nodes valued p and q, or -1 if q is unreachable."""
    graph = collections.defaultdict(list)
    pending = [(root, None)]  # (node, parent) pairs still to be linked
    while pending:
        node, parent = pending.pop()
        if not node:
            continue
        if parent:
            graph[node.val].append(parent.val)
        for child in (node.left, node.right):
            if child:
                graph[node.val].append(child.val)
                pending.append((child, node))

    seen = {p}
    frontier = [p]
    distance = 0
    while frontier:
        if q in frontier:
            return distance
        next_frontier = []
        for value in frontier:
            for neighbor in graph[value]:
                if neighbor not in seen:
                    seen.add(neighbor)
                    next_frontier.append(neighbor)
        frontier = next_frontier
        distance += 1

    return -1

https://leetcode.com/problems/binary-tree-maximum-path-sum


Need for parent access

  • If you find yourself thinking "I need to know this node's parent" or "I need to move upward from this node", that's often a signal that a graph representation might be helpful.

https://leetcode.com/problems/find-all-the-lonely-nodes


https://leetcode.com/problems/find-nearest-right-node-in-binary-tree


https://leetcode.com/problems/lowest-common-ancestor-of-deepest-leaves


Misc

NOTES

  • An important property of binary trees is that each node has at most 2 children; when a complete binary tree is stored level by level in an array (as in a heap), child positions can be computed directly:

    Binary Tree in Array Representation: if a binary tree is stored level by level in an array:

    1. **Index 0** represents the root node.
    2. For a node at index i:
    • **Left Child**: The left child is located at index 2i + 1.
    • **Right Child**: The right child is located at index 2i + 2.